import os
import numpy as np
import yaml
import copy
import random
from torch.utils.data import Dataset
from PIL import Image
from image_synthesis.data.utils.image_path_dataset import ImagePaths
import image_synthesis.data.utils.imagenet_utils as imagenet_utils
from image_synthesis.utils.misc import instantiate_from_config, get_all_file
from image_synthesis.data.utils.util import generate_stroke_mask
from image_synthesis.utils.io import load_dict_from_json


class ImageListDataset(Dataset):
    """
    Load images given a file listing their paths (or by scanning a directory).

    Image paths are resolved against ``<data_root>/<name>`` (``data_root``
    defaults to ``'data'``). Each item is the dict produced by ``ImagePaths``
    with ``'image'`` replaced by the preprocessed image as a float32
    3 x H x W array. Optionally a coordinate map (``'coord'``) and a mask
    (``'mask'``, bool 1 x H x W) are added.
    """
    def __init__(self, 
                 name,
                 image_list_file,
                 data_root='',
                 coord=False,
                 im_preprocessor_config={
                     'target': 'image_synthesis.data.utils.image_preprocessor.SimplePreprocessor',
                     'params':{
                        'size': 256,
                        'random_crop': True,
                        'horizon_flip': True
                        }
                 },
                 image_end_with='',
                 mask=-1.0, # mask probability
                 stroken_mask_params=None,
                 multi_image_mask=False,
                 return_data_keys=None
                 ):
        """
        Args:
            name: sub-directory of ``data_root`` holding the images.
            image_list_file: text file with one relative image path per line.
                Pass ``''`` to collect every file under the root whose name
                ends with one of the suffixes in ``image_end_with``.
            data_root: root directory; ``''`` means ``'data'``.
            coord: if True, also build a normalized pixel-index map that is
                passed through the preprocessor together with the image.
            im_preprocessor_config: config given to ``instantiate_from_config``
                to build the image preprocessor.
            image_end_with: comma-separated file suffixes. Used only when
                ``image_list_file`` is ``''``.
            mask: probability of attaching a mask to an item (<= 0 disables).
            stroken_mask_params: kwargs for ``generate_stroke_mask``. A default
                set is used when None.
            multi_image_mask: if True, multiply the image by the mask.
            return_data_keys: if given, only these keys are returned per item.

        Raises:
            NotImplementedError: if ``image_list_file`` is neither an existing
                file nor ``''``.
        """
        super().__init__()
        self.name = name
        self.image_list_file = image_list_file

        root = os.path.join(data_root if data_root != '' else 'data', self.name)

        if os.path.isfile(self.image_list_file):
            with open(self.image_list_file, "r") as f:
                relpaths = f.read().splitlines()
            paths = [os.path.join(root, relpath) for relpath in relpaths]
        elif self.image_list_file == '':
            assert image_end_with != ''
            paths = get_all_file(root, end_with=image_end_with.split(','))
        else:
            raise NotImplementedError
        self.data = ImagePaths(paths=paths)

        # get preprocessor
        self.preprocessor = instantiate_from_config(im_preprocessor_config)
        self.coord = coord
        self.mask = mask
        self.stroken_mask_params = stroken_mask_params
        self.masks = None # the provided masks
        self.multi_image_mask = multi_image_mask
        self.return_data_keys = return_data_keys

    def __len__(self):
        return len(self.data)

    def get_mask(self, hw, index=None):
        """
        Return a mask of shape H x W x 1.

        If ``self.masks`` is set, the mask at ``index`` is used. A random one
        is picked when ``index`` is None or out of range. Otherwise a stroke
        mask of size ``hw`` is generated with ``generate_stroke_mask``.

        Raises:
            ValueError: if the mask is neither 2-D nor 3-D.
        """
        if self.masks is not None:
            if index is None or index >= len(self.masks):
                index = random.randint(0, len(self.masks) - 1)
            mask = self.masks[index]
        else:
            if self.stroken_mask_params is None:
                stroken_mask_params = {
                    'max_parts': 15,
                    'maxVertex': 25,
                    'maxLength': 100, 
                    'maxBrushWidth': 24
                }
            else:
                # Copy so the user-supplied config dict is not mutated below.
                stroken_mask_params = dict(self.stroken_mask_params)
            stroken_mask_params['im_size'] = hw
            mask = generate_stroke_mask(**stroken_mask_params)

        if len(mask.shape) == 3:
            # Keep only the first channel, preserving the channel axis.
            mask = mask[:, :, 0:1]
        elif len(mask.shape) == 2:
            mask = mask[:, :, np.newaxis]
        else:
            raise ValueError('Invalid shape of mask: {}'.format(mask.shape))
        return mask

    def __getitem__(self, index):
        data = self.data[index]

        if self.coord:
            h, w, _ = data['image'].shape
            # Each pixel stores its flattened index scaled to [0, 1). The map
            # goes through the same preprocessing as the image, presumably so
            # callers can recover where each output pixel came from.
            coord = (np.arange(h * w).reshape(h, w, 1) / (h * w)).astype(np.float32)
            out = self.preprocessor(image=data["image"], coord=coord)
            data['image'] = np.transpose(out["image"].astype(np.float32), (2, 0, 1))
            data["coord"] = np.transpose(out["coord"].astype(np.float32), (2, 0, 1))
        else:
            image = self.preprocessor(image=data['image'])['image']
            data['image'] = np.transpose(image.astype(np.float32), (2, 0, 1)) # 3 x H x W

        if random.random() < self.mask:
            mask = self.get_mask(hw=data['image'].shape[1:3])
            # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            data['mask'] = np.transpose(mask.astype(bool), (2, 0, 1)) # 1 x H x W

            if self.multi_image_mask:
                data['image'] = data['image'] * data['mask'].astype(np.float32)

        if self.return_data_keys is None:
            return data
        return {k: data[k] for k in self.return_data_keys}

class ImageListImageNet(Dataset):
    """
    ImageNet-style dataset whose class label is the image's parent directory.

    Images are listed from ``image_list_file`` (paths relative to
    ``<data_root>/<name>``), or found by scanning that directory for files
    ending with one of ``file_end_with``. Each item carries the preprocessed
    image (float32 3 x H x W), ``'class_id'``, ``'class_name'`` and a
    randomly generated ``'text'`` label.
    """
    def __init__(self, 
                 name,
                 image_list_file='',
                 data_root='',
                 im_preprocessor_config={
                     'target': 'image_synthesis.data.utils.image_preprocessor.SimplePreprocessor',
                     'params':{
                        'size': 256,
                        'random_crop': True,
                        'horizon_flip': True
                        }
                 },  
                 file_end_with='.JPEG',
                 ):
        """
        Args:
            name: sub-directory of ``data_root`` with the images.
            image_list_file: optional text file of relative image paths. If it
                is not an existing file, the directory is scanned instead.
            data_root: root directory; ``''`` means ``'data'``.
            im_preprocessor_config: config given to ``instantiate_from_config``.
            file_end_with: comma-separated suffixes used when scanning.
        """
        super().__init__()
        self.name = name
        self.image_list_file = image_list_file
        self.file_end_with = file_end_with
        self.data_root = data_root if data_root != '' else 'data'

        # load dataset
        self._load()

        # get preprocessor
        self.preprocessor = instantiate_from_config(im_preprocessor_config)

    def _get_class_info(self, path_to_yaml="data/imagenet_class_to_idx.yaml"):
        """
        Load the class-name -> index mapping from YAML.

        Returns:
            (class2id, id2class): the loaded mapping and its inverse.
        """
        with open(path_to_yaml) as f:
            class2id = yaml.full_load(f)
        # Invert class2id (the source mapping, not the dict being built).
        id2class = {idx: classn for classn, idx in class2id.items()}
        return class2id, id2class

    def _load(self):
        """Collect image paths and their class ids, then build ``self.data``."""
        self.class2id, self.id2class = self._get_class_info()

        if os.path.isfile(self.image_list_file):
            with open(self.image_list_file, "r") as f:
                relative_path = f.read().splitlines()
            self.abspaths = [os.path.join(self.data_root, self.name, p) for p in relative_path]
            print('Found {} files with the given {}'.format(len(self.abspaths), self.image_list_file))
        else:
            search_dir = os.path.join(self.data_root, self.name)
            file_end_with = self.file_end_with.split(',')
            self.abspaths = get_all_file(search_dir, end_with=file_end_with)
            print('Found {} files by searching {} with extensions {}'.format(len(self.abspaths), search_dir, str(self.file_end_with)))

        # The class of each image is the name of its parent directory.
        self.class_labels = [self.class2id[s.split(os.sep)[-2]] for s in self.abspaths]
        labels = {
            "abs_path": self.abspaths,
            "class_id": np.array(self.class_labels).astype(np.int64),
        }
        self.data = ImagePaths(paths=self.abspaths,
                               labels=labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        image = self.preprocessor(image=data['image'])['image']
        data['image'] = np.transpose(image.astype(np.float32), (2, 0, 1)) # 3 x H x W
        data['class_name'] = imagenet_utils.IMAGENET_CLASSES[int(data['class_id'])]
        data['text'] = imagenet_utils.get_random_text_label(class_names=[data['class_name']])[0]

        return data

class ImageListImageText(Dataset):
    """
    Image-caption dataset for 'cub-200-2011', 'flowers102' and
    'multi-modal-celeba-hq'.

    Image paths are read from ``image_list_file`` (relative to
    ``<data_root>/<name>``). Each image's caption file is located by applying
    ``image_to_caption_dir_replace`` and ``image_to_caption_ext_replace`` to
    its relative path. An item holds the image (float32 3 x H x W) and one
    randomly chosen caption. Optionally it also holds a random stroke mask
    (``'mask'``, bool 1 x H x W) and a low-resolution, colour-quantized
    ``'inferior'`` image.
    """
    def __init__(    
        self,             
        name,
        image_list_file,
        data_root='',
        load_random_mask=False,
        all_masked=-1.0,
        image_load_size=None, # height, width
        image_to_caption_dir_replace=['images/', 'captions/'],
        image_to_caption_ext_replace=['.jpg', '.txt'],
        im_preprocessor_config={
            'target': 'image_synthesis.data.utils.image_preprocessor.SimplePreprocessor',
            'params':{
            'size': [256, 256],
            'smallest_max_size': 256,
            'random_crop': True,
            'horizon_flip': True
            }
        },
        im_preprocessor_config_hr=None, # It is useful for image completion
        text_tokenizer_config=None,
        # args for image inpainting
        inferior_size=None, # h, w
        inferior_random_degree=2,
        mask_low_to_high=-1.0,
        mask_type=1,
        pixel_kmens_center_path='data/kmeans_centers.npy',
    ):
        """
        Args:
            name: dataset name; must be one of the three supported names.
            image_list_file: text file with one relative image path per line.
            data_root: root directory; ``''`` means ``'data'``.
            load_random_mask: if True, attach a random stroke mask to each item.
            all_masked: probability that the generated mask is set to all zeros.
            image_load_size: (height, width) to resize loaded images to, or None.
            image_to_caption_dir_replace: [old, new] substring replacement
                mapping an image path to its caption path.
            image_to_caption_ext_replace: [old, new] extension replacement for
                the caption path.
            im_preprocessor_config: config for the image preprocessor. Items
                skip preprocessing when the instantiated preprocessor is None.
            im_preprocessor_config_hr: must be None (not implemented).
            text_tokenizer_config: if it yields a tokenizer, captions rejected
                by ``check_length`` are dropped, as are images left with none.
            inferior_size: (h, w) of the inferior image; None disables it.
            inferior_random_degree: each pixel of the inferior image is
                quantized to a random one of its K nearest colour centers.
            mask_low_to_high: probability of downsizing the mask to
                ``inferior_size`` before it is resized back to the image size.
            mask_type: 1 or 2, selects the stroke-mask parameters.
            pixel_kmens_center_path: ``.npy`` file of colour cluster centers.

        Raises:
            AssertionError: if ``name`` is not supported.
            NotImplementedError: if ``im_preprocessor_config_hr`` is not None.
        """
        super().__init__()
        self.name = name
        assert self.name in ['cub-200-2011', 'flowers102', 'multi-modal-celeba-hq']

        self.data_root = 'data' if data_root == '' else data_root
        self.image_list_file = image_list_file
        self.load_random_mask = load_random_mask
        self.mask_type = mask_type
        self.all_masked = all_masked
        self.image_to_caption_dir_replace = image_to_caption_dir_replace
        self.image_to_caption_ext_replace = image_to_caption_ext_replace
        self.image_load_size = image_load_size

        self.preprocessor = instantiate_from_config(im_preprocessor_config)
        if im_preprocessor_config_hr is not None:
            raise NotImplementedError
        self.preprocessor_hr = instantiate_from_config(im_preprocessor_config_hr)
        self.text_tokenizer = instantiate_from_config(text_tokenizer_config)

        # for priors (the "inferior" image)
        self.inferior_size = inferior_size
        self.inferior_random_degree = inferior_random_degree
        self.mask_low_to_high = mask_low_to_high
        self.pixel_centers = np.load(pixel_kmens_center_path)
        # Map the centers to the [0, 255] pixel range (they are presumably
        # stored in [-1, 1]).
        self.pixel_centers = np.rint(127.5 * (1 + self.pixel_centers))

        self._load()
        self._filter_based_on_text()

    def _load(self):
        """Read the image list, every image's captions and, for CUB, bounding boxes."""
        with open(self.image_list_file, 'r') as f:
            relative_image_path = [p.strip() for p in f.readlines()]

        abs_image_path = []
        image_path_to_captions = {}
        for p in relative_image_path:
            image_p = os.path.join(self.data_root, self.name, p)
            abs_image_path.append(image_p)
            p = p.replace(self.image_to_caption_dir_replace[0], self.image_to_caption_dir_replace[1])
            p = p.replace(self.image_to_caption_ext_replace[0], self.image_to_caption_ext_replace[1])
            text_p = os.path.join(self.data_root, self.name, p)
            with open(text_p, 'r') as txt_f:
                caps = [c.strip() for c in txt_f.readlines()]
            # A caption must be longer than 5 characters to be kept.
            image_path_to_captions[image_p] = [c for c in caps if len(c) > 5]
        self.abs_image_path = abs_image_path
        self.image_path_to_captions = image_path_to_captions

        # load boxes if needed
        if self.name == 'cub-200-2011':
            image_to_box = load_dict_from_json(os.path.join(self.data_root, self.name, 'image_to_box.json'))
            self.image_path_to_box = {
                os.path.join(self.data_root, self.name, k): image_to_box[k] for k in image_to_box
            }

    def _filter_based_on_text(self):
        """Drop captions the tokenizer rejects and images left without any (no-op without a tokenizer)."""
        if self.text_tokenizer is None:
            return
        abs_im_path = []
        for im_path in self.abs_image_path:
            captions = [
                txt for txt in self.image_path_to_captions[im_path]
                if self.text_tokenizer.check_length(self.text_tokenizer.get_tokens([txt])[0])
            ]
            if len(captions) > 0:
                self.image_path_to_captions[im_path] = captions
                abs_im_path.append(im_path)
        print('Filter data done: from {} to {}'.format(len(self.abs_image_path), len(abs_im_path)))
        self.abs_image_path = abs_im_path

    def load_image(self, im_path, resize=True):
        """
        Load an image as a uint8 RGB array.

        For CUB, the image is first cropped to a square around its bounding
        box. The box is read as (x, y, w, h). The square's half-side is
        0.75 * max(w, h), clipped to the image bounds. If ``resize`` is True
        and ``image_load_size`` is set, the image is then resized to
        (height, width).
        """
        im = Image.open(im_path)
        if not im.mode == "RGB":
            im = im.convert("RGB")

        if hasattr(self, 'image_path_to_box'):
            box = self.image_path_to_box[im_path]
            width, height = im.size

            r = int(np.maximum(box[2], box[3]) * 0.75)
            center_x = int((2 * box[0] + box[2]) / 2)
            center_y = int((2 * box[1] + box[3]) / 2)
            y1 = np.maximum(0, center_y - r)
            y2 = np.minimum(height, center_y + r)
            x1 = np.maximum(0, center_x - r)
            x2 = np.minimum(width, center_x + r)
            im = im.crop([x1, y1, x2, y2])

        if self.image_load_size is not None and resize:
            # PIL expects (width, height).
            im = im.resize((self.image_load_size[1], self.image_load_size[0]), Image.BILINEAR)
        return np.array(im).astype(np.uint8)

    def load_caption(self, im_path):
        """Return one caption of ``im_path`` chosen uniformly at random."""
        captions = self.image_path_to_captions[im_path]
        return captions[random.randint(0, len(captions) - 1)]

    def load_mask(self, im):
        """
        Generate a random 256 x 256 stroke mask (H x W x 1, uint8).

        ``im`` is not used. ``mask_type`` selects the stroke parameters. With
        probability ``all_masked`` the mask is zeroed. With probability
        ``mask_low_to_high`` (and ``inferior_size`` set) it is downsized to
        ``inferior_size``.

        Raises:
            NotImplementedError: for an unknown ``mask_type``.
        """
        if self.mask_type == 1:
            max_vertex, max_brush_width = 25, 24
        elif self.mask_type == 2:
            max_vertex, max_brush_width = 50, 40
        else:
            raise NotImplementedError
        mask = generate_stroke_mask(im_size=[256, 256],
                                    max_parts=15,
                                    maxVertex=max_vertex,
                                    maxLength=100, 
                                    maxBrushWidth=max_brush_width) # H x W x 1
        if random.random() < self.all_masked:
            mask = mask * 0

        if self.inferior_size is not None and random.random() < self.mask_low_to_high:
            h, w = self.inferior_size[0], self.inferior_size[1]
            mask = Image.fromarray(mask[:, :, 0].astype(np.uint8))
            mask = mask.resize((w, h), resample=Image.NEAREST)
            mask = np.array(mask)[:, :, np.newaxis]

        return mask.astype(np.uint8)

    def load_inferior(self, im):
        """
        Build the "inferior" image: a low-resolution, colour-quantized copy of ``im``.

        ``im`` is resized to ``self.inferior_size`` (h, w). Each pixel is then
        replaced by one of its ``self.inferior_random_degree`` nearest colour
        centers in ``self.pixel_centers``, chosen uniformly at random. The
        result is resized back to the size of ``im``.

        Returns:
            uint8 array of shape H x W x 3.
        """
        def squared_euclidean_distance_np(a, b):
            # Pairwise ||a_i - b_j||^2 via the expansion a^2 - 2ab + b^2.
            b = b.T
            a2 = np.sum(np.square(a), axis=1)
            b2 = np.sum(np.square(b), axis=0)
            ab = np.matmul(a, b)
            return a2[:, None] - 2 * ab + b2[None, :]

        def color_quantize_np_topK(x, clusters, K):
            x = x.reshape(-1, 3)
            d = squared_euclidean_distance_np(x, clusters)
            # Indices of the K closest centers per pixel (unordered).
            top_K = np.argpartition(d, K, axis=1)[:, :K]

            h, w = top_K.shape
            select_index = np.random.randint(w, size=(h))
            return top_K[range(h), select_index]

        def inferior_degradation(img, clusters, prior_size, K=1): ## Downsample and random change
            LR_img = img.resize((prior_size[1], prior_size[0]), resample=Image.BILINEAR)
            LR_img = np.array(LR_img)

            token_id = color_quantize_np_topK(LR_img.astype(clusters.dtype), clusters, K)
            primers = token_id.reshape(-1, prior_size[0] * prior_size[1])
            primers_img = [np.reshape(clusters[s], [prior_size[0], prior_size[1], 3]).astype(np.uint8) for s in primers]

            return Image.fromarray(primers_img[0]) ## degraded by inferior cluster

        h, w = im.shape[0:2]

        inferior = Image.fromarray(im.astype(np.uint8)).convert("RGB")
        inferior = inferior_degradation(inferior, self.pixel_centers, self.inferior_size, K=self.inferior_random_degree)
        inferior = inferior.resize((w, h), resample=Image.BILINEAR)
        return np.array(inferior).astype(np.uint8)

    def __len__(self):
        return len(self.abs_image_path)

    def get_data_for_ui_demo(self, index):
        """Return the unresized image, its file name and all its captions."""
        image_path = self.abs_image_path[index]
        im = self.load_image(image_path, resize=False)

        data = {}
        data['file_name'] = os.path.basename(image_path)
        data['image'] = im
        data['captions'] = self.image_path_to_captions[image_path]
        return data

    def __getitem__(self, index):
        image_path = self.abs_image_path[index]
        im = self.load_image(image_path)
        caption = self.load_caption(image_path)
        # Test the flag, not the bound method ``self.load_mask``. A bound
        # method is always truthy, so testing it would load a mask every time.
        if self.load_random_mask:
            mask = self.load_mask(im)

        # preprocess image
        if self.preprocessor is not None:
            im = self.preprocessor(image=im)['image']

        data = {
            'image': np.transpose(im.astype(np.float32), (2, 0, 1)), # 3 x H x W
            'text': caption
        }

        if self.load_random_mask:
            # Resize the mask to the (preprocessed) image size.
            h, w = im.shape[0:2]
            mask = Image.fromarray(mask[:, :, 0]).resize((w, h), resample=Image.NEAREST)
            mask = np.array(mask)[:, :, np.newaxis]
            # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            data['mask'] = np.transpose(mask.astype(bool), (2, 0, 1)) # 1 x H x W

        if self.inferior_size is not None:
            inferior = self.load_inferior(im)
            data['inferior'] = np.transpose(inferior.astype(np.float32), (2, 0, 1)) # 3 x H x W

        return data
